{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 9: The EM Algorithm and Its Extensions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Expectation-Maximization algorithm\n",
    "\n",
    "### Maximum likelihood function\n",
    "\n",
    "[likelihood & maximum likelihood](http://fangs.in/post/thinkstats/likelihood/)"
   ]
  },
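  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. The EM algorithm is an iterative algorithm for maximum likelihood estimation, or maximum a posteriori estimation, of probabilistic models with latent variables. The data of such a model are described by $P(Y, Z | \\theta)$, where $Y$ is the observed-variable data, $Z$ is the latent-variable data, and $\\theta$ is the model parameter. EM performs maximum likelihood estimation by iteratively maximizing the log-likelihood of the observed data, ${L}(\\theta)=\\log {P}(\\mathrm{Y} | \\theta)$. Each iteration consists of two steps:\n",
    "\n",
    "E-step: compute the expectation of $\\log P\\left(Y, Z | \\theta\\right)$ with respect to $P\\left(Z | Y, \\theta^{(i)}\\right)$:\n",
    "\n",
    "$$Q\\left(\\theta, \\theta^{(i)}\\right)=\\sum_{Z} \\log P(Y, Z | \\theta) P\\left(Z | Y, \\theta^{(i)}\\right)$$\n",
    "This is called the $Q$ function; here $\\theta^{(i)}$ is the current estimate of the parameter.\n",
    "\n",
    "M-step: maximize the $Q$ function to obtain the new parameter estimate:\n",
    "\n",
    "$$\\theta^{(i+1)}=\\arg \\max _{\\theta} Q\\left(\\theta, \\theta^{(i)}\\right)$$\n",
    "\n",
    "When constructing a concrete EM algorithm, the key is to define the $Q$ function. In each iteration, EM increases the log-likelihood ${L}(\\theta)$ by maximizing the $Q$ function.\n",
    "\n",
    "2. The likelihood of the observed data never decreases from one EM iteration to the next, i.e.\n",
    "\n",
    "$$P\\left(Y | \\theta^{(i+1)}\\right) \\geqslant P\\left(Y | \\theta^{(i)}\\right)$$\n",
    "\n",
    "Under general conditions EM converges, but convergence to the global optimum is not guaranteed.\n",
    "\n",
    "3. EM is very widely used, mainly for learning probabilistic models with latent variables. Parameter estimation for Gaussian mixture models is one important application; unsupervised learning of hidden Markov models, introduced in the next chapter, is another.\n",
    "\n",
    "4. EM can also be interpreted as a maximization-maximization algorithm on the $F$ function. EM has many variants, such as the GEM algorithm. The characteristic of GEM is that each iteration increases the value of the $F$ function (without necessarily maximizing it), and thereby increases the likelihood.\n"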
   ]
  },
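  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The monotonicity property in point 2 follows from a short derivation. The sketch below is the standard argument; the auxiliary function $H(\\theta, \\theta^{(i)})$ is notation introduced here for the derivation and is not part of the summary above.\n",
    "\n",
    "Since $P(Y | \\theta)=P(Y, Z | \\theta) / P(Z | Y, \\theta)$, taking logarithms and then the expectation with respect to $P(Z | Y, \\theta^{(i)})$ gives\n",
    "\n",
    "$$\\log P(Y | \\theta)=Q(\\theta, \\theta^{(i)})-H(\\theta, \\theta^{(i)}), \\qquad H(\\theta, \\theta^{(i)})=\\sum_{Z} \\log P(Z | Y, \\theta) P(Z | Y, \\theta^{(i)})$$\n",
    "\n",
    "Evaluating at $\\theta^{(i+1)}$ and at $\\theta^{(i)}$ and subtracting:\n",
    "\n",
    "$$\\log P(Y | \\theta^{(i+1)})-\\log P(Y | \\theta^{(i)})=\\left[Q(\\theta^{(i+1)}, \\theta^{(i)})-Q(\\theta^{(i)}, \\theta^{(i)})\\right]-\\left[H(\\theta^{(i+1)}, \\theta^{(i)})-H(\\theta^{(i)}, \\theta^{(i)})\\right]$$\n",
    "\n",
    "The first bracket is non-negative because $\\theta^{(i+1)}$ maximizes $Q(\\theta, \\theta^{(i)})$. The second bracket is non-positive by Jensen's inequality:\n",
    "\n",
    "$$H(\\theta^{(i+1)}, \\theta^{(i)})-H(\\theta^{(i)}, \\theta^{(i)})=\\sum_{Z} P(Z | Y, \\theta^{(i)}) \\log \\frac{P(Z | Y, \\theta^{(i+1)})}{P(Z | Y, \\theta^{(i)})} \\leqslant \\log \\sum_{Z} P(Z | Y, \\theta^{(i+1)})=0$$\n",
    "\n",
    "Hence $\\log P(Y | \\theta^{(i+1)}) \\geqslant \\log P(Y | \\theta^{(i)})$."
   ]
  },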
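  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> In statistics, the likelihood function (often shortened to likelihood) is a very important concept. In informal usage, likelihood and probability are near synonyms, but in statistics they are two distinct concepts. Probability is the chance that an event occurs under given conditions: before the outcome is known, we use the parameters that describe those conditions to predict how likely the event is. Take a coin toss: before tossing we do not know which side will land up, but from the properties of the coin we can infer that each side comes up with probability 50%. This probability is only meaningful before the toss; once the coin has landed, the outcome is fixed. Likelihood works in the opposite direction: given a fixed outcome, we infer the conditions (parameters) that could have produced it. Continuing the coin example, suppose we toss a coin 1,000 times and get 500 heads and 500 tails (real results are rarely this tidy; this is only an illustration). We readily conclude that the coin is fair, with each side having probability 50%. Using observed outcomes to judge the properties (parameters) of the underlying process is what likelihood is about."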
   ]
  },
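  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the three-coin model with observations $y_1,\\dots,y_n$, where $\\pi$, $p$ and $q$ are the heads probabilities of coins A, B and C, the likelihood is\n",
    "\n",
    "$$P(Y|\\theta) = \\prod_{j=1}^{n}\\left[\\pi p^{y_j}(1-p)^{1-y_j}+(1-\\pi) q^{y_j}(1-q)^{1-y_j}\\right]$$\n",
    "\n",
    "### E step:\n",
    "\n",
    "$$\\mu_j^{(i+1)}=\\frac{\\pi^{(i)} (p^{(i)})^{y_j}(1-p^{(i)})^{1-y_j}}{\\pi^{(i)} (p^{(i)})^{y_j}(1-p^{(i)})^{1-y_j}+(1-\\pi^{(i)}) (q^{(i)})^{y_j}(1-q^{(i)})^{1-y_j}}$$"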
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import math"
   ]
  },
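  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "pro_A, pro_B, pro_C = 0.5, 0.5, 0.5\n",
    "\n",
    "\n",
    "def pmf(i, pro_A, pro_B, pro_C):\n",
    "    # E-step responsibility: posterior probability that observation data[i]\n",
    "    # came from coin B. `data` is the global list defined in a later cell.\n",
    "    pro_1 = pro_A * math.pow(pro_B, data[i]) * math.pow(\n",
    "        (1 - pro_B), 1 - data[i])\n",
    "    # Coin C is chosen with probability 1 - pro_A\n",
    "    pro_2 = (1 - pro_A) * math.pow(pro_C, data[i]) * math.pow(\n",
    "        (1 - pro_C), 1 - data[i])\n",
    "    return pro_1 / (pro_1 + pro_2)"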
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### M step:\n",
    "\n",
    "$$\\pi^{(i+1)}=\\frac{1}{n}\\sum_{j=1}^n\\mu_j^{(i+1)}$$\n",
    "\n",
    "$$p^{(i+1)}=\\frac{\\sum_{j=1}^n\\mu_j^{(i+1)}y_j}{\\sum_{j=1}^n\\mu_j^{(i+1)}}$$\n",
    "\n",
    "$$q^{(i+1)}=\\frac{\\sum_{j=1}^n(1-\\mu_j^{(i+1)})y_j}{\\sum_{j=1}^n(1-\\mu_j^{(i+1)})}$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EM:\n",
    "    def __init__(self, prob):\n",
    "        # prob = (pi, p, q): heads probabilities of coins A, B and C\n",
    "        self.pro_A, self.pro_B, self.pro_C = prob\n",
    "\n",
    "    # E-step: posterior probability that observation i came from coin B\n",
    "    def pmf(self, i):\n",
    "        pro_1 = self.pro_A * math.pow(self.pro_B, self.data[i]) * math.pow(\n",
    "            (1 - self.pro_B), 1 - self.data[i])\n",
    "        pro_2 = (1 - self.pro_A) * math.pow(self.pro_C, self.data[i]) * math.pow(\n",
    "            (1 - self.pro_C), 1 - self.data[i])\n",
    "        return pro_1 / (pro_1 + pro_2)\n",
    "\n",
    "    # Generator: after the first next(), each send() runs one E-step and M-step\n",
    "    def fit(self, data):\n",
    "        self.data = data  # stored so that pmf() does not depend on a global\n",
    "        count = len(data)\n",
    "        print('init prob:{}, {}, {}'.format(self.pro_A, self.pro_B,\n",
    "                                            self.pro_C))\n",
    "        for d in range(count):\n",
    "            _ = yield\n",
    "            # E-step: responsibilities under the current parameters\n",
    "            _pmf = [self.pmf(k) for k in range(count)]\n",
    "            # M-step: re-estimate pi, p and q\n",
    "            pro_A = 1 / count * sum(_pmf)\n",
    "            pro_B = sum([_pmf[k] * data[k] for k in range(count)]) / sum(\n",
    "                [_pmf[k] for k in range(count)])\n",
    "            pro_C = sum([(1 - _pmf[k]) * data[k]\n",
    "                         for k in range(count)]) / sum([(1 - _pmf[k])\n",
    "                                                        for k in range(count)])\n",
    "            print('{}/{}  pro_a:{:.3f}, pro_b:{:.3f}, pro_c:{:.3f}'.format(\n",
    "                d + 1, count, pro_A, pro_B, pro_C))\n",
    "            self.pro_A = pro_A\n",
    "            self.pro_B = pro_B\n",
    "            self.pro_C = pro_C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "data=[1,1,0,1,0,0,1,0,1,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "init prob:0.5, 0.5, 0.5\n"
     ]
    }
   ],
   "source": [
    "em = EM(prob=[0.5, 0.5, 0.5])\n",
    "f = em.fit(data)\n",
    "next(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/10  pro_a:0.500, pro_b:0.600, pro_c:0.600\n"
     ]
    }
   ],
   "source": [
    "# First iteration\n",
    "f.send(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2/10  pro_a:0.500, pro_b:0.600, pro_c:0.600\n"
     ]
    }
   ],
   "source": [
    "# Second iteration\n",
    "f.send(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "init prob:0.4, 0.6, 0.7\n"
     ]
    }
   ],
   "source": [
    "em = EM(prob=[0.4, 0.6, 0.7])\n",
    "f2 = em.fit(data)\n",
    "next(f2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/10  pro_a:0.406, pro_b:0.537, pro_c:0.643\n"
     ]
    }
   ],
   "source": [
    "f2.send(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2/10  pro_a:0.406, pro_b:0.537, pro_c:0.643\n"
     ]
    }
   ],
   "source": [
    "f2.send(2)"
   ]
  },
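  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The iteration above can also be written in vectorized form with NumPy, tracking the log-likelihood $\\log P(Y|\\theta)$ after every step to check that it never decreases. This is a minimal sketch rather than a reference implementation: the names `log_likelihood`, `three_coin_em` and `obs`, and the `tol` and `max_iter` defaults, are illustrative choices made here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def log_likelihood(y, pi, p, q):\n",
    "    # log P(Y | theta) of the three-coin model for a 0/1 array y\n",
    "    b = pi * p**y * (1 - p)**(1 - y)\n",
    "    c = (1 - pi) * q**y * (1 - q)**(1 - y)\n",
    "    return float(np.sum(np.log(b + c)))\n",
    "\n",
    "\n",
    "def three_coin_em(y, pi, p, q, tol=1e-5, max_iter=100):\n",
    "    # Run EM from (pi, p, q); return the final parameters and the\n",
    "    # log-likelihood recorded before the first step and after each step.\n",
    "    y = np.asarray(y, dtype=float)\n",
    "    trace = [log_likelihood(y, pi, p, q)]\n",
    "    for _ in range(max_iter):\n",
    "        # E-step: responsibility of coin B for every observation\n",
    "        b = pi * p**y * (1 - p)**(1 - y)\n",
    "        c = (1 - pi) * q**y * (1 - q)**(1 - y)\n",
    "        mu = b / (b + c)\n",
    "        # M-step: closed-form updates of pi, p and q\n",
    "        new_pi = mu.mean()\n",
    "        new_p = np.sum(mu * y) / np.sum(mu)\n",
    "        new_q = np.sum((1 - mu) * y) / np.sum(1 - mu)\n",
    "        delta = abs(new_pi - pi) + abs(new_p - p) + abs(new_q - q)\n",
    "        pi, p, q = new_pi, new_p, new_q\n",
    "        trace.append(log_likelihood(y, pi, p, q))\n",
    "        if delta < tol:\n",
    "            break\n",
    "    return (pi, p, q), trace\n",
    "\n",
    "\n",
    "obs = [1, 1, 0, 1, 0, 0, 1, 0, 1, 1]\n",
    "params, trace = three_coin_em(obs, 0.4, 0.6, 0.7)\n",
    "print('pi={:.3f}, p={:.3f}, q={:.3f}'.format(*params))\n",
    "print('log-likelihood trace:', np.round(trace, 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Started from (0.4, 0.6, 0.7), this should reproduce the estimates printed by the generator version above (about 0.406, 0.537, 0.643), and the log-likelihood trace should be non-decreasing."
   ]
  },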
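  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chapter 9: The EM Algorithm and Its Extensions - Exercises\n",
    "### Exercise 9.1\n",
    "\n",
    "For the three-coin model of Example 9.1, keep the observed data unchanged and choose different initial values, for example $\\pi^{(0)}=0.46,p^{(0)}=0.55,q^{(0)}=0.67$. Find the maximum likelihood estimate of the model parameters $\\theta=(\\pi,p,q)$.  \n",
    "\n",
    "**Solution:**"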
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import math\n",
    "\n",
    "class EM:\n",
    "    def __init__(self, prob):\n",
    "        # prob = (pi, p, q): heads probabilities of coins A, B and C\n",
    "        self.pro_A, self.pro_B, self.pro_C = prob\n",
    "\n",
    "    # E-step: posterior probability that observation i came from coin B\n",
    "    def pmf(self, i):\n",
    "        pro_1 = self.pro_A * math.pow(self.pro_B, self.data[i]) * math.pow(\n",
    "            (1 - self.pro_B), 1 - self.data[i])\n",
    "        pro_2 = (1 - self.pro_A) * math.pow(self.pro_C, self.data[i]) * math.pow(\n",
    "            (1 - self.pro_C), 1 - self.data[i])\n",
    "        return pro_1 / (pro_1 + pro_2)\n",
    "\n",
    "    def fit(self, data):\n",
    "        self.data = data  # stored so that pmf() does not depend on a global\n",
    "        print('init prob:{}, {}, {}'.format(self.pro_A, self.pro_B,\n",
    "                                            self.pro_C))\n",
    "        count = len(data)\n",
    "        delta = 1  # total absolute parameter change in the last iteration\n",
    "        d = 0\n",
    "        while (delta > 0.00001):\n",
    "            # E-step: responsibilities under the current parameters\n",
    "            _pmf = [self.pmf(k) for k in range(count)]\n",
    "            # M-step: re-estimate pi, p and q\n",
    "            pro_A = 1 / count * sum(_pmf)\n",
    "            pro_B = sum([_pmf[k] * data[k] for k in range(count)]) / sum(\n",
    "                [_pmf[k] for k in range(count)])\n",
    "            pro_C = sum([(1 - _pmf[k]) * data[k]\n",
    "                         for k in range(count)]) / sum([(1 - _pmf[k])\n",
    "                                                        for k in range(count)])\n",
    "            d += 1\n",
    "            print('{}  pro_a:{:.4f}, pro_b:{:.4f}, pro_c:{:.4f}'.format(\n",
    "                d, pro_A, pro_B, pro_C))\n",
    "            delta = abs(self.pro_A - pro_A) + abs(self.pro_B -\n",
    "                                                  pro_B) + abs(self.pro_C -\n",
    "                                                               pro_C)\n",
    "            self.pro_A = pro_A\n",
    "            self.pro_B = pro_B\n",
    "            self.pro_C = pro_C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "init prob:0.46, 0.55, 0.67\n",
      "1  pro_a:0.4619, pro_b:0.5346, pro_c:0.6561\n",
      "2  pro_a:0.4619, pro_b:0.5346, pro_c:0.6561\n"
     ]
    }
   ],
   "source": [
    "# Load the data\n",
    "data = [1, 1, 0, 1, 0, 0, 1, 0, 1, 1]\n",
    "\n",
    "em = EM(prob=[0.46, 0.55, 0.67])\n",
    "em.fit(data)"
   ]
  },
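  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After two iterations the parameters have converged: $\\pi=0.4619$, $p=0.5346$, $q=0.6561$ (the heads probabilities of coins A, B and C). Note that EM is sensitive to the initial values: the runs started from (0.5, 0.5, 0.5) and (0.4, 0.6, 0.7) above converge to different estimates."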
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 9.2\n",
    "Prove Lemma 9.2."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> **Lemma 9.2:** If $\\tilde{P}_{\\theta}(Z)=P(Z | Y, \\theta)$, then $$F(\\tilde{P}, \\theta)=\\log P(Y|\\theta)$$\n",
    "\n",
    "**Proof:**  \n",
    "By the definition of the $F$ function (**Definition 9.3**): $$F(\\tilde{P}, \\theta)=E_{\\tilde{P}}[\\log P(Y,Z|\\theta)] + H(\\tilde{P})$$ where $H(\\tilde{P})=-E_{\\tilde{P}} \\log \\tilde{P}(Z)$.  \n",
    "$\\begin{aligned}\n",
    "\\therefore F(\\tilde{P}, \\theta) \n",
    "&= E_{\\tilde{P}}[\\log P(Y,Z|\\theta)] -E_{\\tilde{P}} \\log \\tilde{P}(Z) \\\\\n",
    "&= \\sum_Z \\log P(Y,Z|\\theta) \\tilde{P}_{\\theta}(Z) - \\sum_Z \\log \\tilde{P}_{\\theta}(Z) \\cdot \\tilde{P}_{\\theta}(Z) \\\\\n",
    "&= \\sum_Z \\log P(Y,Z|\\theta) P(Z|Y,\\theta) -  \\sum_Z \\log P(Z|Y,\\theta) \\cdot P(Z|Y,\\theta) \\\\\n",
    "&= \\sum_Z P(Z|Y,\\theta) \\left[ \\log P(Y,Z|\\theta) - \\log P(Z|Y,\\theta) \\right] \\\\\n",
    "&= \\sum_Z P(Z|Y,\\theta) \\log \\frac{P(Y,Z|\\theta)}{P(Z|Y,\\theta)} \\\\\n",
    "&= \\sum_Z P(Z|Y,\\theta) \\log P(Y|\\theta) \\\\\n",
    "&= \\log P(Y|\\theta) \\sum_Z P(Z|Y,\\theta)\n",
    "\\end{aligned}$  \n",
    "$\\displaystyle \\because \\sum_Z \\tilde{P}_{\\theta}(Z) = \\sum_Z P(Z|Y, \\theta) = 1$  \n",
    "$\\therefore F(\\tilde{P}, \\theta) = \\log P(Y|\\theta)$, which proves Lemma 9.2."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 9.3\n",
    "Given the observed data  \n",
    "-67, -48, 6, 8, 14, 16, 23, 24, 28, 29, 41, 49, 56, 60, 75  \n",
    "estimate the five parameters of a two-component Gaussian mixture model.\n",
    "\n",
    "**Solution:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "labels = [1 1 0 0 0 0 0 0 0 0 0 0 0 0 0]\n"
     ]
    }
   ],
   "source": [
    "from sklearn.mixture import GaussianMixture\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# 初始化观测数据\n",
    "data = np.array([-67, -48, 6, 8, 14, 16, 23, 24, 28, 29, 41, 49, 56, 60,\n",
    "                 75]).reshape(-1, 1)\n",
    "\n",
    "# 聚类\n",
    "gmmModel = GaussianMixture(n_components=2)\n",
    "gmmModel.fit(data)\n",
    "labels = gmmModel.predict(data)\n",
    "print(\"labels =\", labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAEWCAYAAACNJFuYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAZ40lEQVR4nO3de5RlZX3m8e8jpSLYCKSb5tJoo2IiIaQlJQM6YRKRiIrgmkknGC9EnZAYb0nMiMRZSXTFeEmWxkxudkTFSDQ03tqMJiJqmEREG20vgEpHIrRAU0aB9gaW/OaPvQsPvU91FXSf2nU4389avc7Zl3P271RVn2e/77svqSokSRp0n74LkCQtP4aDJKnDcJAkdRgOkqQOw0GS1GE4SJI6DAfdKyX5UJIze9z+g5N8O8lefdWw1JL8XJJti1z3D5O8Y9Q16Z4zHLRHJDkjyWVJvpPkpvb5byZJH/VU1ROr6rw9/b5JfjVJJXn9TvOf2s5/W7v9a6vqgVX1w0W859uS/NGernWBbVaS7UmmBuZNtb87T36S4aDdl+QlwBuBPwEOBlYDvwE8Frhfj6WNyr8Dvzz4xQo8C/hKH8XsVMfdcTPwxIHpJwHf2v2KdG9gOGi3JHkQ8ErgN6vqwqraUY3PVtXTq+q2dr0nJ/lskluTXJfkDwfeo9MdkeQ/kjy+fX5cks3ta7fP7bUn2TvJO5L8Z5Kbk3w6yep22ceT/M/2+cOSfLRd7xtJzk+y/07b+t0kn09yS5J/SLL3Lj72jcAXgCe0rz8QeAywaeA917Z751NJDkyyLclT2mUPTLI1ybOSnAU8HXhp2w31gXadSvLwgfe7s3Ux9/NKcnaSG4G3tvNPTbKl/Vl8IskxC/z6/o4m1OY8C3j7Tr+HQ5NsSvLNtuZfG1j2gLaubyW5Enj0kNe+O8lMkmuSvGiBerSMGA7aXScA9wfev8B636H58tkfeDLwvCRPXeQ23gi8sar2Ax4GXNDOPxN4EHA48GM0rZXvDXl9gFcDhwKPbNf/w53W+SXgFOAI4BjgVxeo6e386Iv1DJrPf9uwFavqm8BzgL9NchDwBmBLVb29qjYA5wOva7uhnrLAduccDBwIPAQ4K8mxwFuAX6f5WbwJ2JTk/rt4j/cBJybZvw3Ln6X7e3wnsI3mZ/eLwB8nOald9gc0v4+H0QTlnWM8Se4DfAD4HHAYcBLwW0mesMjPp54ZDtpdK4FvVNXs3Ix2r/XmJN9LciJAVX28qr5QVXdU1edpvnT+2yK38QPg4UlWVtW3q+qTA/N/DHh4Vf2wqi6vqlt3fnFVba2qi6rqtqqaAV4/ZNt/XlXXt1/kHwDWLVDTe4Gfa1tOnT3uITV8GNgIXEwTjr++wPsv5A7gD9rP9D3g14A3VdVl7c/iPJqwOn4X7/F9ms/6yzQBt6mdB0CSw4H/CpxdVd+vqi3Am4Fntqv8EvCqqvpmVV0H/PnAez8aWFVVr6yq26vqq8DfttvRGDActLv+E1g52O9dVY+pqv3bZfcBSPJfknys7WK4hWYvf+Uit/Fc4BHAl9quo1Pb+X8H/DPwriTXJ3ldkvvu/OIkByV5V5KvJ7kVeMeQbd848Py7wAN3VVD7hfx/gf8NrKyqf1vE59gAHA28tar+cxHr78pMVX1/YPohwEvaUL45yc00LaRDF3ifuRbQsIA7FPhmVe0YmPc1mpbA3PLrdlo2WM+hO9XzezTjURoDhoN216U0e6inL7De39PsmR5eVQ8C/oamuweaLqd95lZMc/jnqrnpqrq6qp4GHAS8Frgwyb5V9YOqekVVHUXT538qd+1Dn/NqoIBj2q6pZwxse3e8HXgJTUjtUvuZ3tS+5nmD4wltbTv7LgM/E5pupEE7v+Y6mr34/Qf+7VNV71ygtP8HHELzpf2vOy27HjgwyYqBeQ8Gvt4+v4EmgAaXDdZzzU71rKiqJy1Qj5YJw0G7
papuBl4B/FWSX2wHW++TZB2w78CqK2j2Qr+f5DjgVwaWfQXYux20vi/N3vidfeVJnpFkVVXdQXOEDcAPk/x8kp9qv3hvpelmGnbo6Arg28DNSQ4D/tce+fDwL8DJwP9ZxLq/1z4+B/hT4O350TkQ24GH7rT+FuBXkuyV5BQW7oL7W+A32hZakuzb/jxX7OpF1Vyz/ynAabXT9fvbrqJPAK9uB/+PoWnFnd+ucgFwTpIDkqwBXjjw8k8Bt7aD5g9oP8fRSe4yaK3ly3DQbquq1wG/A7wUuInmy+5NwNk0Xy4Avwm8MskO4Pf50aAyVXVLu/zNNHul36EZBJ1zCnBFkm/TDE6f0XapHAxcSBMMV9F8WQ87seoVwLHALTRdQe/Z7Q/d1F1VdXE7TjGvJD9D8/N5Vnvew2tp9vxf1q5yLnBU2/3yvnbei2m+tG+mOZrpfexCVW2mGXf4C5rDUbey8KD63GuvqKor5ln8NGAtTSvivTTjHBe1y15B05V0DfBhBlpQ7ed8Cs3YzTXAN2h+vw9aTE3qX7zZjyRpZ7YcJEkdhoMkqcNwkCR1GA6SpI57esGuZWXlypW1du3avsuQpLFy+eWXf6OqVg1bdq8Ih7Vr17J58+a+y5CksZLka/Mt67VbKclvJ7kiyReTvLM90eaINPcCuDrN1THvjZd8lqRlrbdwaM9UfREwXVVHA3vRXJTrtcAbqupImpN5nttXjZI0qfoekJ4CHtBetG0fmmu1PI7mrFeA84DFXtZZkrSH9BYOVfV1mmvMXEsTCrcAlwM3D1z+eRs/ugLkXSQ5K80NYDbPzMwsRcmSNDH67FY6gOZKnkfQXPp3X+56y8I5Q6/vUVUbqmq6qqZXrRo62C5Juof67FZ6PM0lfWeq6gc0F0N7DLD/wL0B1tBc8EuStIT6DIdrgeOT7JMkNLcRvBL4GM3tCKG57eBCt5+UJO1hfY45XEYz8PwZmpu134fmTllnA7+TZCvNLSDP7atGSZpUvR6tVFV/UFU/UVVHV9Uz2/vhfrWqjquqh1fV+qoaetN2SZpI27fDiSfCfvs1j9u3j2QzfR/KKkm6O9avh0svhR07msf160eyGcNBksbJli0w2x7tPzvbTI+A4SBJ42TdOphqD+icmmqmR8BwkKRxsnEjnHACrFjRPG7cOJLN3CuuyipJE2P1arjkkpFvxpaDJKnDcJAkdRgOkqQOw0GS1GE4SJI6DAdJUofhIEnqMBwkSR2GgySpw3CQJHUYDpKkDsNBktRhOEiSOgwHSVJHr+GQZP8kFyb5UpKrkpyQ5MAkFyW5un08oM8aJWkS9d1yeCPwT1X1E8BPA1cBLwMurqojgYvbaUkaL9u3w4knwn77NY/bt/dd0d3SWzgk2Q84ETgXoKpur6qbgdOB89rVzgOe2k+FkrQb1q+HSy+FHTuax/Xr+67obumz5fBQYAZ4a5LPJnlzkn2B1VV1A0D7eNCwFyc5K8nmJJtnZmaWrmpJWowtW2B2tnk+O9tMj5E+w2EKOBb466p6FPAd7kYXUlVtqKrpqppetWrVqGqUpHtm3TqYau/EPDXVTI+RPsNhG7Ctqi5rpy+kCYvtSQ4BaB9v6qk+SbrnNm6EE06AFSuax40b+67obuktHKrqRuC6JD/ezjoJuBLYBJzZzjsTeH8P5UmaJKMYPF69Gi65BG69tXlcvXr333MJTfW8/RcC5ye5H/BV4Nk0gXVBkucC1wLjNYojafzMDR7Pzv5o8PiSS/quqle9hkNVbQGmhyw6aalrkTTBxnzweBT6Ps9Bkvo35oPHo2A4SNKYDx6PQt9jDpLUv7nBY93JloMkqcNwkCR1GA6SpA7DQZLUYThIGh9jfhnscWI4SBofY34Z7HFiOEgaH57JvGQMB0njwzOZl4zhIGl8eCbzkvEMaUnjwzOZl4wtB0lSh+EgSeowHCRJHYaDJKnDcJAkdfQeDkn2SvLZJP/YTh+R5LIkVyf5h/b+0pLGjZe6GGu9
hwPwYuCqgenXAm+oqiOBbwHP7aUqSbvHS12MtV7DIcka4MnAm9vpAI8DLmxXOQ94aj/VSdotXupirPXdcvgz4KXAHe30jwE3V1X7F8U24LBhL0xyVpLNSTbPzMyMvlJJd4+XuhhrvYVDklOBm6rq8sHZQ1atYa+vqg1VNV1V06tWrRpJjZJ2g5e6GGt9Xj7jscBpSZ4E7A3sR9OS2D/JVNt6WANc32ONku4pL3Ux1nprOVTVOVW1pqrWAmcAH62qpwMfA36xXe1M4P09lShJE6vvMYdhzgZ+J8lWmjGIc3uuR5ImzrK4KmtVfRz4ePv8q8BxfdYjSZNuObYcJEk9MxwkSR2GgySpw3CQJHUYDpKkDsNBktRhOEiSOgwHadJ53wUNYThIk877LmgIw0EaF6Paw/e+CxrCcJDGxaj28L3vgoYwHKRxMao9fO+7oCGWxYX3JC3CunVNi2F2ds/u4XvfBQ1hy0EaF+7hawkZDtIojGLweG4P/9Zbm8fVq3f/PaV5GA7SKHh4qMac4aDJ5uGh0lCGgyabh4dKQxkOmmweHioN1Vs4JDk8yceSXJXkiiQvbucfmOSiJFe3jwf0VaMmwKj28B081pjrs+UwC7ykqh4JHA88P8lRwMuAi6vqSODidloaDffwpaF6Owmuqm4Abmif70hyFXAYcDrwc+1q5wEfB87uoURNAk8Ak4ZaFmMOSdYCjwIuA1a3wTEXIAfN85qzkmxOsnlmZmapSpWkidB7OCR5IPBu4Leq6tbFvq6qNlTVdFVNr1q1anQFStIE6jUcktyXJhjOr6r3tLO3JzmkXX4IcFNf9UnSpOrzaKUA5wJXVdXrBxZtAs5sn58JvH+pa9My5R3LpCXTZ8vhscAzgccl2dL+exLwGuDkJFcDJ7fTkpekkJZQn0cr/SuQeRaftJS1aEx4SQppyfQ+IK17oVF1/3hJCmnJGA7a80bV/eMJa9KS8U5w2vNG1f3jCWvSkrHloD3P7h9p7BkO2vPs/pHGnt1K2vPs/pHGni0HSVKH4SBJ6jAcJpmXo5A0D8Nhknk5CknzMBwmmZejkDQPw2FcjKILyPMRJM3DcBgXo+gC8nwESfPwPIdxMYouIM9HkDSPBVsOSV6Q5IClKEa7YBeQpCW0mG6lg4FPJ7kgySntHdy01OwCkrSEUlULr9QEwi8AzwamgQuAc6vq30db3uJMT0/X5s2b+y5DksZKksuranrYskUNSFeTIDe2/2aBA4ALk7xuj1UpSVo2FjPm8KIklwOvA/4N+Kmqeh7wM8D/GFVhbRfWl5NsTfKyUW1HktS1mKOVVgL/vaq+Njizqu5IcuooikqyF/CXwMnANpoxj01VdeUotidJuqsFWw5V9fs7B8PAsqv2fEkAHAdsraqvVtXtwLuA00e0LUnSTpbrSXCHAdcNTG9r590pyVlJNifZPDMzs6TFSdK93XINh2GHy97lsKqq2lBV01U1vWrVqiUqS5Imw3INh23A4QPTa4Dre6pFkibOcg2HTwNHJjkiyf2AM4BNPdckSRNjWV5bqapmk7wA+GdgL+AtVXVFz2VJ0sRYluEAUFUfBD7Ydx2SNImWa7eSJKlHhoMkqcNwkCR1GA6SpA7DQZLUYThIkjoMB0lSh+EgSeowHCRJHYaDJKnDcJAkdRgOkqQOw0GS1GE4SJI6DAdJUofhIEnqMBwkSR2GgySpw3CQJHX0Eg5J/iTJl5J8Psl7k+w/sOycJFuTfDnJE/qoT5ImXV8th4uAo6vqGOArwDkASY4CzgB+EjgF+Kske/VUoyRNrF7Coao+XFWz7eQngTXt89OBd1XVbVV1DbAVOK6PGiVpki2HMYfnAB9qnx8GXDewbFs7ryPJWUk2J9k8MzMz4hIlabJMjeqNk3wEOHjIopdX1fvbdV4OzALnz71syPo17P2ragOwAWB6enroOpKke2Zk4VBVj9/V8iRnAqcCJ1XV3Jf7NuDwgdXWANePpkJJ0nz6OlrpFOBs
4LSq+u7Aok3AGUnun+QI4EjgU33UKEmTbGQthwX8BXB/4KIkAJ+sqt+oqiuSXABcSdPd9Pyq+mFPNUrSxOolHKrq4btY9irgVUtYjiRpJ8vhaCVJ0jJjOEiSOgwHSVKH4SBJ6jAcJEkdhoMkqcNwkCR1GA6SpA7DQZLUYThIkjoMB0lSh+EgSeowHCRJHYaDJKnDcJAkdRgOkqQOw0GS1GE4SJI6eg2HJL+bpJKsbKeT5M+TbE3y+STH9lmfJE2q3sIhyeHAycC1A7OfCBzZ/jsL+OseSpOkiddny+ENwEuBGph3OvD2anwS2D/JIb1UJ0kTrJdwSHIa8PWq+txOiw4DrhuY3tbOkyQtoalRvXGSjwAHD1n0cuD3gF8Y9rIh82rIPJKcRdP1xIMf/OB7WKUkaZiRhUNVPX7Y/CQ/BRwBfC4JwBrgM0mOo2kpHD6w+hrg+nnefwOwAWB6enpogEiS7pkl71aqqi9U1UFVtbaq1tIEwrFVdSOwCXhWe9TS8cAtVXXDUtcoSZNuZC2He+iDwJOArcB3gWf3W44kTabew6FtPcw9L+D5/VUjSQLPkJYkDWE4SJI6DAdJUofhIEnqMBz2sO3b4cQTYb/9msft2/uuSJLuPsNhD1u/Hi69FHbsaB7Xr++7Ikm6+wyHPWzLFpidbZ7PzjbTkjRuDIc9bN06mGrPHpmaaqYladwYDnvYxo1wwgmwYkXzuHFj3xVJ0t3X+xnS9zarV8Mll/RdhSTtHlsOkqQOw0GS1GE4SJI6DAdJUofhIEnqMBwkSR2GgySpw3CQJHUYDpKkjt7CIckLk3w5yRVJXjcw/5wkW9tlT+irPkmaZL1cPiPJzwOnA8dU1W1JDmrnHwWcAfwkcCjwkSSPqKof9lGnJE2qvloOzwNeU1W3AVTVTe3804F3VdVtVXUNsBU4rqcaJWli9RUOjwB+NsllSf4lyaPb+YcB1w2st62d15HkrCSbk2yemZkZcbmSNFlG1q2U5CPAwUMWvbzd7gHA8cCjgQuSPBTIkPVr2PtX1QZgA8D09PTQdSRJ98zIwqGqHj/fsiTPA95TVQV8KskdwEqalsLhA6uuAa4fVY2SpOH66lZ6H/A4gCSPAO4HfAPYBJyR5P5JjgCOBD7VU42SNLH6utnPW4C3JPkicDtwZtuKuCLJBcCVwCzwfI9UkqSl10s4VNXtwDPmWfYq4FVLW5EkaZBnSEuSOgwHSVLHRIfD9u1w4omw337N4/btfVckScvDRIfD+vVw6aWwY0fzuH593xVJ0vIw0eGwZQvMzjbPZ2ebaUnShIfDunUw1R6vNTXVTEuSJjwcNm6EE06AFSuax40b+65IkpaHvk6CWxZWr4ZLLum7Cklafia65SBJGs5wkCR1GA6SpA7DQZLUYThIkjoMB0lSR5rbKIy3JDPA13bjLVbS3GxoHFjraFjr6IxTvZNW60OqatWwBfeKcNhdSTZX1XTfdSyGtY6GtY7OONVrrT9it5IkqcNwkCR1GA6NDX0XcDdY62hY6+iMU73W2nLMQZLUYctBktRhOEiSOiY6HJKckuTLSbYmeVnf9cwnyeFJPpbkqiRXJHlx3zUtJMleST6b5B/7rmUhSfZPcmGSL7U/4xP6rmk+SX67/Rv4YpJ3Jtm775rmJHlLkpuSfHFg3oFJLkpydft4QJ81Dpqn3j9p/w4+n+S9Sfbvs8Y5w2odWPa7SSrJyj25zYkNhyR7AX8JPBE4CnhakqP6rWpes8BLquqRwPHA85dxrXNeDFzVdxGL9Ebgn6rqJ4CfZpnWneQw4EXAdFUdDewFnNFvVXfxNuCUnea9DLi4qo4ELm6nl4u30a33IuDoqjoG+ApwzlIXNY+30a2VJIcDJwPX7ukNTmw4AMcBW6vqq1V1O/Au4PSeaxqqqm6oqs+0z3fQfHkd1m9V80uyBngy8Oa+a1lIkv2AE4FzAarq9qq6ud+qdmkKeECSKWAf4Pqe67lTVV0CfHOn2acD57XP
zwOeuqRF7cKweqvqw1XV3lmeTwJrlrywIeb52QK8AXgpsMePLJrkcDgMuG5gehvL+At3TpK1wKOAy/qtZJf+jOYP9o6+C1mEhwIzwFvbbrA3J9m376KGqaqvA39Ks5d4A3BLVX2436oWtLqqboBmJwc4qOd67o7nAB/qu4j5JDkN+HpVfW4U7z/J4ZAh85b1cb1JHgi8G/itqrq173qGSXIqcFNVXd53LYs0BRwL/HVVPQr4Dsur6+NObX/96cARwKHAvkme0W9V905JXk7TnXt+37UMk2Qf4OXA749qG5McDtuAwwem17CMmug7S3JfmmA4v6re03c9u/BY4LQk/0HTVfe4JO/ot6Rd2gZsq6q5ltiFNGGxHD0euKaqZqrqB8B7gMf0XNNCtic5BKB9vKnnehaU5EzgVODptXxPBHsYzU7C59r/a2uAzyQ5eE9tYJLD4dPAkUmOSHI/moG9TT3XNFSS0PSJX1VVr++7nl2pqnOqak1VraX5mX60qpbt3m1V3Qhcl+TH21knAVf2WNKuXAscn2Sf9m/iJJbp4PmATcCZ7fMzgff3WMuCkpwCnA2cVlXf7bue+VTVF6rqoKpa2/5f2wYc2/497xETGw7toNMLgH+m+Q92QVVd0W9V83os8EyavfAt7b8n9V3UvcgLgfOTfB5YB/xxz/UM1bZuLgQ+A3yB5v/vsrncQ5J3ApcCP55kW5LnAq8BTk5yNc1RNa/ps8ZB89T7F8AK4KL2/9nf9Fpka55aR7vN5dtqkiT1ZWJbDpKk+RkOkqQOw0GS1GE4SJI6DAdJUofhIEnqMBwkSR2GgzQCSR7d3hNg7yT7tvdgOLrvuqTF8iQ4aUSS/BGwN/AAmus3vbrnkqRFMxykEWmv2fVp4PvAY6rqhz2XJC2a3UrS6BwIPJDmWj3L5nae0mLYcpBGJMkmmsuWHwEcUlUv6LkkadGm+i5AujdK8ixgtqr+vr1f+SeSPK6qPtp3bdJi2HKQJHU45iBJ6jAcJEkdhoMkqcNwkCR1GA6SpA7DQZLUYThIkjr+Pw5PPY71Yn0iAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "means = [[ 32.98489643 -57.51107027]]\n",
      "covariances = [[429.45764867  90.24987882]]\n",
      "weights =  [[0.86682762 0.13317238]]\n"
     ]
    }
   ],
   "source": [
    "for i in range(0, len(labels)):\n",
    "    if labels[i] == 0:\n",
    "        plt.scatter(i, data.take(i), s=15, c='red')\n",
    "    elif labels[i] == 1:\n",
    "        plt.scatter(i, data.take(i), s=15, c='blue')\n",
    "plt.title('Gaussian Mixture Model')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('y')\n",
    "plt.show()\n",
    "print(\"means =\", gmmModel.means_.reshape(1, -1))\n",
    "print(\"covariances =\", gmmModel.covariances_.reshape(1, -1))\n",
    "print(\"weights = \", gmmModel.weights_.reshape(1, -1))"
   ]
  },
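  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scikit-learn call above hides the EM updates. The cell below is a minimal from-scratch sketch for the one-dimensional, two-component case, written with NumPy only. The function name `gmm_em_1d`, the starting values (means -50 and 30, variances 500, equal weights) and the stopping rule are assumptions made here for illustration; scikit-learn uses its own initialization and a small covariance regularization, so its digits can differ slightly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def gmm_em_1d(x, mu, sigma2, alpha, n_iter=200, tol=1e-8):\n",
    "    # EM for a 1-D Gaussian mixture; mu, sigma2 and alpha hold one entry\n",
    "    # per component (means, variances, mixing weights).\n",
    "    x = np.asarray(x, dtype=float)\n",
    "    mu = np.asarray(mu, dtype=float)\n",
    "    sigma2 = np.asarray(sigma2, dtype=float)\n",
    "    alpha = np.asarray(alpha, dtype=float)\n",
    "    prev_ll = -np.inf\n",
    "    for _ in range(n_iter):\n",
    "        # E-step: gamma[j, k] = responsibility of component k for x[j]\n",
    "        diff = x[:, None] - mu\n",
    "        dens = np.exp(-diff**2 / (2 * sigma2)) / np.sqrt(2 * np.pi * sigma2)\n",
    "        weighted = alpha * dens\n",
    "        total = weighted.sum(axis=1, keepdims=True)\n",
    "        gamma = weighted / total\n",
    "        # Log-likelihood under the parameters used in this E-step\n",
    "        ll = float(np.log(total).sum())\n",
    "        # M-step: weighted means, variances and mixing weights\n",
    "        nk = gamma.sum(axis=0)\n",
    "        mu = (gamma * x[:, None]).sum(axis=0) / nk\n",
    "        sigma2 = (gamma * (x[:, None] - mu)**2).sum(axis=0) / nk\n",
    "        alpha = nk / len(x)\n",
    "        if ll - prev_ll < tol:\n",
    "            break\n",
    "        prev_ll = ll\n",
    "    return alpha, mu, sigma2, ll\n",
    "\n",
    "\n",
    "x = np.array([-67, -48, 6, 8, 14, 16, 23, 24, 28, 29, 41, 49, 56, 60, 75])\n",
    "alpha, mu, sigma2, ll = gmm_em_1d(\n",
    "    x, mu=[-50.0, 30.0], sigma2=[500.0, 500.0], alpha=[0.5, 0.5])\n",
    "print('weights   =', np.round(alpha, 4))\n",
    "print('means     =', np.round(mu, 4))\n",
    "print('variances =', np.round(sigma2, 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this starting point the estimates should land close to the scikit-learn output above (means near -57.5 and 33), although the component order and the trailing digits may differ. The five parameters asked for are the two means, the two variances and one free mixing weight."
   ]
  },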
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 9.4\n",
    "The EM algorithm can be applied to unsupervised learning of naive Bayes. Write out the algorithm."
   ]
  },
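  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Solution:** \n",
    "> **EM in its general form:**  \n",
    "**E-step:** using the initial parameters or the parameters from the previous iteration, compute the posterior probability of the latent variable, which is in effect its expectation, and take it as the current estimate of the latent variable: $$w_j^{(i)}=Q_{i}(z^{(i)}=j) := p(z^{(i)}=j | x^{(i)} ; \\theta)$$\n",
    "**M-step:** maximize the resulting lower bound on the log-likelihood to obtain the new parameter values: $$\n",
    "\\theta :=\\arg \\max_{\\theta} \\sum_i \\sum_{z^{(i)}} Q_i (z^{(i)}) \\log \\frac{p(x^{(i)}, z^{(i)} ; \\theta)}{Q_i (z^{(i)})}\n",
    "$$  \n",
    "\n",
    "Replace the parameter $\\theta$ of the general EM algorithm with the parameters $\\phi_z,\\phi_{j|z^{(i)}=1},\\phi_{j|z^{(i)}=0}$ of the NBMM (naive Bayes mixture model), then solve in turn for the updates of $w^{(i)}$ and of $\\phi_z,\\phi_{j|z^{(i)}=1},\\phi_{j|z^{(i)}=0}$. Here $m$ is the number of samples, $x_j^{(i)}$ is the $j$-th feature of sample $i$, $z^{(i)} \\in \\{0, 1\\}$ is its latent class, and $I(\\cdot)$ is the indicator function.  \n",
    "**EM for the NBMM:**  \n",
    "**E-step:**  \n",
    "$$w^{(i)}:=P\\left(z^{(i)}=1 | x^{(i)} ; \\phi_z, \\phi_{j | z^{(i)}=1}, \\phi_{j | z^{(i)}=0}\\right)$$**M-step:**$$\n",
    "\\begin{aligned}\n",
    "\\phi_{j | z^{(i)}=1} &:=\\frac{\\displaystyle \\sum_{i=1}^{m} w^{(i)} I(x_{j}^{(i)}=1)}{\\displaystyle \\sum_{i=1}^{m} w^{(i)}} \\\\\\\\ \n",
    "\\phi_{j | z^{(i)}=0} &:= \\frac{\\displaystyle  \\sum_{i=1}^{m}\\left(1-w^{(i)}\\right) I(x_{j}^{(i)}=1)}{ \\displaystyle \\sum_{i=1}^{m}\\left(1-w^{(i)}\\right)} \\\\\\\\ \n",
    "\\phi_{z} &:=\\frac{\\displaystyle \\sum_{i=1}^{m} w^{(i)}}{m} \n",
    "\\end{aligned}\n",
    "$$   "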
   ]
  },
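  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The E-step and M-step above can be sketched in NumPy for binary features and two latent classes. This is an illustrative sketch under stated assumptions, not a reference implementation: the function name `nbmm_em`, the random initialization, the clipping constant `1e-6` (used only to keep the logarithms finite) and the synthetic data are all choices made here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def nbmm_em(X, n_iter=50, seed=0):\n",
    "    # X: (m, n) matrix of 0/1 features; latent class z takes values 0 or 1\n",
    "    rng = np.random.default_rng(seed)\n",
    "    m, n = X.shape\n",
    "    phi_z = 0.5                              # P(z = 1)\n",
    "    phi_1 = rng.uniform(0.25, 0.75, size=n)  # P(x_j = 1 | z = 1)\n",
    "    phi_0 = rng.uniform(0.25, 0.75, size=n)  # P(x_j = 1 | z = 0)\n",
    "    trace = []\n",
    "    for _ in range(n_iter):\n",
    "        # E-step: w[i] = P(z = 1 | x^(i)), computed in log space\n",
    "        log_p1 = (np.log(phi_z) + X @ np.log(phi_1)\n",
    "                  + (1 - X) @ np.log(1 - phi_1))\n",
    "        log_p0 = (np.log(1 - phi_z) + X @ np.log(phi_0)\n",
    "                  + (1 - X) @ np.log(1 - phi_0))\n",
    "        log_norm = np.logaddexp(log_p1, log_p0)\n",
    "        w = np.exp(log_p1 - log_norm)\n",
    "        # Log-likelihood under the parameters used in this E-step\n",
    "        trace.append(float(log_norm.sum()))\n",
    "        # M-step: weighted frequencies, clipped away from 0 and 1\n",
    "        phi_1 = np.clip((w @ X) / w.sum(), 1e-6, 1 - 1e-6)\n",
    "        phi_0 = np.clip(((1 - w) @ X) / (1 - w).sum(), 1e-6, 1 - 1e-6)\n",
    "        phi_z = float(np.clip(w.mean(), 1e-6, 1 - 1e-6))\n",
    "    return phi_z, phi_1, phi_0, w, trace\n",
    "\n",
    "\n",
    "# Synthetic data: 200 samples, 6 binary features, two hidden classes\n",
    "rng = np.random.default_rng(1)\n",
    "z_true = rng.random(200) < 0.5\n",
    "probs = np.where(z_true[:, None], 0.8, 0.2)\n",
    "X = (rng.random((200, 6)) < probs).astype(float)\n",
    "\n",
    "phi_z, phi_1, phi_0, w, trace = nbmm_em(X)\n",
    "print('phi_z =', round(phi_z, 3))\n",
    "print('phi_1 =', np.round(phi_1, 3))\n",
    "print('phi_0 =', np.round(phi_0, 3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Which latent class ends up labelled 1 depends on the random initialization, so the learned `phi_1` and `phi_0` may be swapped relative to the generating classes."
   ]
  },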
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----\n",
    "Reference code: https://github.com/wzyonggege/statistical-learning-method\n",
    "\n",
    "Updated code for this notebook: https://github.com/fengdu78/lihang-code\n",
    "\n",
    "Exercise solutions: https://github.com/datawhalechina/statistical-learning-method-solutions-manual\n",
    "\n",
    "Chinese annotations by the WeChat public account 机器学习初学者 (ID: ai-start-com)\n",
    "\n",
    "Environment: Python 3.5+\n",
    "\n",
    "All code has been tested.\n",
    "![gongzhong](../gongzhong.jpg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
